import torch
import torch.nn.functional as F

# Example data: scores for two samples over five classes.
# NOTE: F.cross_entropy applies log-softmax to its input internally, so these
# values are interpreted as raw logits rather than as probabilities.
prediction = torch.tensor(
    [
        [0.0, 1.0, 0.0, 0.0, 0.0],
        [1.0, 0.0, 0.0, 0.0, 0.0],
    ]
)  # shape [2, 5]

# Class indices, initially laid out as a column vector of shape [2, 1].
target = torch.tensor([[1], [0]])

# Drop the singleton axis so target has shape [2], one class index per sample.
target = target.squeeze(1)

# Cross-entropy loss, averaged over the samples (the default reduction).
loss = F.cross_entropy(input=prediction, target=target, reduction="mean")

print(loss)
